{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[View the runnable example on GitHub](https://github.com/intel-analytics/BigDL/tree/main/python/nano/tutorial/notebook/training/tensorflow/tensorflow_training_bf16.ipynb)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use BFloat16 Mixed Precision for TensorFlow Keras Training"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> ⚠️ **Warning**\n",
    "> \n",
    "> This feature is under rapid iteration, and its usage may change in the future."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;*The following example is adapted from* https://www.tensorflow.org/guide/mixed_precision"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Brain Floating Point Format (BFloat16) is a custom 16-bit floating point format designed for machine learning. BFloat16 consists of 1 sign bit, 8 exponent bits, and 7 mantissa bits. Because it has the same number of exponent bits as FP32, BFloat16 covers the same dynamic range while requiring only half the memory.\n",
    "\n",
    "BFloat16 mixed precision combines BFloat16 and FP32 during training, which can improve performance and reduce memory usage. Compared to FP16 mixed precision, BFloat16 mixed precision has better numerical stability.\n",
    "\n",
    "BigDL-Nano provides a TensorFlow patch (`bigdl.nano.tf.patch_tensorflow`) integrated with multiple optimizations. You can call `patch_tensorflow(precision='mixed_bfloat16')` to easily use BFloat16 mixed precision for training."
   ]
  },
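  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the bit layout described above more concrete, the following cell is a minimal, illustrative sketch in plain Python; it is not part of BigDL-Nano or TensorFlow. The helper names `float32_bits` and `to_bfloat16` are hypothetical, and the conversion simply truncates the lower 16 bits of an FP32 value, whereas real hardware and frameworks typically round instead. Under these assumptions, it shows that a value keeps its magnitude (same exponent bits as FP32) but loses precision (only 7 mantissa bits remain)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import struct\n",
    "\n",
    "def float32_bits(value):\n",
    "    # hypothetical helper: pack as IEEE 754 float32, read back as a 32-bit unsigned int\n",
    "    return struct.unpack('>I', struct.pack('>f', value))[0]\n",
    "\n",
    "def to_bfloat16(value):\n",
    "    # hypothetical helper: keep only the upper 16 bits of the float32 pattern\n",
    "    # (1 sign bit, 8 exponent bits, 7 mantissa bits);\n",
    "    # plain truncation is an assumption made for simplicity\n",
    "    bits = float32_bits(value) & 0xFFFF0000\n",
    "    return struct.unpack('>f', struct.pack('>I', bits))[0]\n",
    "\n",
    "# large values keep their magnitude, but fine detail is lost\n",
    "for value in (1.0, 3.14159, 1e38):\n",
    "    print(value, '->', to_bfloat16(value))"
   ]
  },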
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "To use BFloat16 mixed precision for TensorFlow Keras training, you need to install BigDL-Nano for TensorFlow first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "# install the nightly-built version of bigdl-nano for tensorflow;\n",
    "# intel-tensorflow will be installed at the same time, with Intel's oneDNN optimizations enabled by default\n",
    "!pip install --pre --upgrade bigdl-nano[tensorflow]\n",
    "!source bigdl-nano-init  # set environment variables"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    ">\n",
    "> Before starting your TensorFlow Keras application, it is highly recommended to run `source bigdl-nano-init` to set several environment variables based on your current hardware. Empirically, these variables bring a significant performance increase for most TensorFlow Keras training workloads."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "> ⚠️ **Warning**\n",
    "> \n",
    "> For Jupyter Notebook users, we recommend running the commands above, especially `source bigdl-nano-init`, before the Jupyter kernel is started; otherwise some of the optimizations may not take effect."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> ⚠️ **Warning**\n",
    ">\n",
    "> BigDL-Nano enables Intel's oneDNN optimizations by default. oneDNN BFloat16 operations are only supported on platforms with the AVX512 instruction set.\n",
    ">\n",
    "> On platforms without hardware acceleration for BFloat16, BFloat16 training performance may be poor."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Patch TensorFlow"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To conduct BFloat16 mixed precision training, the first step (and in most cases the only one) is to **import** `patch_tensorflow` **from BigDL-Nano, and call it with** `precision` **set to** `'mixed_bfloat16'`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.nano.tf import patch_tensorflow\n",
    "\n",
    "patch_tensorflow(precision='mixed_bfloat16')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    ">\n",
    "> Patching TensorFlow with `'mixed_bfloat16'` as `precision` sets a global `'mixed_bfloat16'` dtype policy, which is used as the default policy for every Keras layer created after the patching.\n",
    ">\n",
    "> A layer with the `'mixed_bfloat16'` dtype policy performs its computations in BFloat16, while keeping its variables in Float32."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build Model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take the [MNIST digits classification dataset](https://keras.io/api/datasets/mnist/) as an example, and suppose that we would like to create a model that will be trained on it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers, Model\n",
    "\n",
    "inputs = keras.Input(shape=(784,), name='digits')\n",
    "\n",
    "dense1 = layers.Dense(units=64, activation='relu', name='dense_1')\n",
    "x = dense1(inputs)\n",
    "dense2 = layers.Dense(units=64, activation='relu', name='dense_2')\n",
    "x = dense2(x)\n",
    "\n",
    "# Note that we separate the Dense layer and the softmax layer\n",
    "# and set 'float32' as the dtype policy here for the last softmax layer\n",
    "x = layers.Dense(10, name='dense_logits')(x)\n",
    "outputs = layers.Activation('softmax', dtype='float32', name='predictions')(x)\n",
    "print(f'Output dtype: {outputs.dtype.name}')\n",
    "\n",
    "model = Model(inputs=inputs, outputs=outputs)\n",
    "model.compile(loss='sparse_categorical_crossentropy',\n",
    "              optimizer=keras.optimizers.RMSprop(),\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    ">\n",
    "> The dtype policy `'float32'` we set here overrides the global `'mixed_bfloat16'` policy for the last layer, so that the model produces a Float32 output tensor.\n",
    ">\n",
    "> It is suggested to give the last layer of the model a `'float32'` dtype policy, so that numeric issues caused by dtype mismatch are avoided when the output tensor flows into the loss."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model.fit"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When training with `Model.fit`, there is nothing special you need to do for BFloat16 mixed precision training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create train/test data\n",
    "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
    "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n",
    "x_test = x_test.reshape(10000, 784).astype('float32') / 255\n",
    "\n",
    "# train with fit\n",
    "model.fit(x_train, y_train,\n",
    "          batch_size=8192,\n",
    "          epochs=10,\n",
    "          validation_split=0.2)\n",
    "\n",
    "test_scores = model.evaluate(x_test, y_test, verbose=2)\n",
    "print('Test loss:', test_scores[0])\n",
    "print('Test accuracy:', test_scores[1])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom training loop"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you create a custom training loop, you should also wrap the train/test step functions with the `@nano_bf16` decorator:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from bigdl.nano.tf.keras import nano_bf16 # import the decorator\n",
    "\n",
    "# create loss function, optimizer, and train/test datasets\n",
    "optimizer = keras.optimizers.RMSprop()\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy()\n",
    "train_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "                 .shuffle(10000).batch(8192))\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(8192)\n",
    "\n",
    "# define train/test step\n",
    "@nano_bf16 # apply the decorator to the train_step\n",
    "@tf.function\n",
    "def train_step(x, y):\n",
    "  with tf.GradientTape() as tape:\n",
    "    predictions = model(x, training=True)\n",
    "    loss = loss_object(y, predictions)\n",
    "  gradients = tape.gradient(loss, model.trainable_variables)\n",
    "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "  return loss\n",
    "\n",
    "@nano_bf16 # apply the decorator to the test_step\n",
    "@tf.function\n",
    "def test_step(x):\n",
    "  return model(x, training=False)\n",
    "\n",
    "# conduct the training\n",
    "for epoch in range(10):\n",
    "  epoch_loss_avg = tf.keras.metrics.Mean()\n",
    "  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')\n",
    "  for x, y in train_dataset:\n",
    "    loss = train_step(x, y)\n",
    "    epoch_loss_avg(loss)\n",
    "  for x, y in test_dataset:\n",
    "    predictions = test_step(x)\n",
    "    test_accuracy.update_state(y, predictions)\n",
    "  print('Epoch {}: loss={}, test accuracy={}'.format(epoch, epoch_loss_avg.result(), test_accuracy.result()))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    ">\n",
    "> If you do not set the `'float32'` dtype policy for the last layer of the model, the model output will be a BFloat16 tensor. In that case, `@nano_bf16` can help avoid dtype mismatch errors: it casts the tensor and numpy ndarray inputs of the decorated train/test step to BFloat16.\n",
    ">\n",
    "> You can also try applying the `@nano_bf16` decorator to other functions in the custom training loop if you encounter dtype mismatch errors."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Optional) Unpatch TensorFlow"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to go back to Float32 training again, you could simply call the `unpatch_tensorflow` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.nano.tf import unpatch_tensorflow\n",
    "\n",
    "unpatch_tensorflow()\n",
    "\n",
    "print(f\"model's dtype policy is still: {model.dtype_policy.name}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    ">\n",
    "> Models created after calling `unpatch_tensorflow` will use `'float32'` as the global dtype policy. However, a model created earlier, under `patch_tensorflow(precision='mixed_bfloat16')`, will still have layers with `'mixed_bfloat16'` as their dtype policy."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📚 **Related Readings**\n",
    "> \n",
    "> - [How to install BigDL-Nano](https://bigdl.readthedocs.io/en/latest/doc/Nano/Overview/install.html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nano-tf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.7.13 (default, Mar 29 2022, 02:18:16) \n[GCC 7.5.0]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "402532f56d486e9f832908f31130bbdf12bd8cb099dfb226783aa2c6b1479100"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
